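"""
Phase 6: Reconciliation with Fragility Scorecard
==================================================
Map which "collapses" and "sign reversals" from the fragility taxonomy
used across earlier papers are now explained by the nonlinear framework.

Fragility categories:
- Robust: survives the subsample (same sign, 0.5x-2x the full-panel magnitude)
- Magnitude-attenuated: weakens in OECD/advanced (below 50% of full-panel magnitude)
- Collapsed: goes to zero in some subsample (significant -> insignificant)
- Sign-reversed: flips sign
- Channel-reversed: different mechanism in different regimes
  (not assigned automatically by this script)

The scorecard also reports Amplified (above 2x the full-panel magnitude),
Emergent (null -> significant), Null in both, and Insufficient data.
Significance is judged at the 10% level throughout.

For each DV, the Z₁ coefficient is re-estimated in every subsample and
classified against the full-panel estimate. Each collapsed, sign-reversed,
or attenuated result is then checked against the Phase 2 varying-coefficient
results. It counts as explained by the framework when the matching
Z₁×moderator interaction is significant.
"""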

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# PROJECT_DIR is the grandparent directory of this file; ROOT_DIR is one above it.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# PanelGLS comes from the sibling "multilateral/src" tree, which is prepended to
# the import path (presumably because it is not installed as a package).
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input CSVs (unified_panel.csv, phase2_varying_coefficients.csv) are read from
# DATA; the markdown output tables are written to OUT_TABLES.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"

# Regressors included in every subsample fit; only the Z_1 coefficient is
# carried into the fragility scorecard.
Z_VARS = ['Z_1', 'Z_2', 'Z_3']

# ── Earlier paper findings to reconcile ───────────────────────────────

# Catalogue of earlier-paper findings this phase is meant to reconcile.
# NOTE(review): FINDINGS is not referenced by main() in this file; the
# scorecard is built from fresh subsample regressions instead. Confirm
# whether another module consumes it.
FINDINGS = [
    # Each entry is a dict with keys:
    #   paper        - source paper label
    #   dv           - outcome column in the unified panel
    #   finding      - the full-panel result as originally reported
    #   oecd_finding - reported behaviour in the restricted sample
    #   sample       - presumably the sample the finding was estimated on
    #                  ('full', 'safe_issuer', 'emu'); verify against callers
    {'paper': 'Paper 1 (Multilateral)',
     'dv': 'ca_gdp',
     'finding': 'Z₁→CA positive and significant in full panel',
     'oecd_finding': 'Weakens or loses significance in OECD-only',
     'sample': 'full'},

    {'paper': 'Paper 1 (Multilateral)',
     'dv': 'gross_savings_gdp',
     'finding': 'Z₁→S positive and significant',
     'oecd_finding': 'Collapses in OECD',
     'sample': 'full'},

    {'paper': 'Paper 1 (Multilateral)',
     'dv': 'gross_investment_gdp',
     'finding': 'Z₁→I positive but weaker than S',
     'oecd_finding': 'Near zero in OECD',
     'sample': 'full'},

    {'paper': 'Paper 3 (Safe Assets)',
     'dv': 'nfa_gdp',
     'finding': 'Safe issuers accumulate more negative NFA',
     'oecd_finding': 'Safe-issuer specific',
     'sample': 'safe_issuer'},

    {'paper': 'Paper 9A (Net-Gross)',
     'dv': 'ca_gdp',
     'finding': 'Income channel dominates in general panel',
     'oecd_finding': 'Trade channel may activate in open economies',
     'sample': 'full'},

    {'paper': 'Paper 10 (Trilemma)',
     'dv': 'ca_gdp',
     'finding': 'EMU amplifies CA deficits via trade channel',
     'oecd_finding': 'Channel switching confirmed',
     'sample': 'emu'},
]


def stars(p):
    """Return conventional significance stars for the p-value ``p``.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1, and an
    empty string otherwise (including NaN, for which every comparison fails).
    """
    for cutoff, mark in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return mark
    return ''


def fmt(val, p):
    """Render ``val`` with three decimals followed by the stars for ``p``."""
    number = "{:.3f}".format(val)
    return number + stars(p)


def run_gls(df, y_var, x_vars, min_obs=50):
    """Fit a PanelGLS regression of ``y_var`` on ``x_vars``.

    Rows with a missing value in ``y_var``, any of ``x_vars``, ``'iso3'`` or
    ``'year'`` are dropped before fitting. The ``iso3`` and ``year`` columns
    are passed to ``PanelGLS.fit`` as the panel identifiers.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing all columns named above.
    y_var : str
        Dependent-variable column.
    x_vars : sequence of str
        Regressor columns, in the order their coefficients are reported.
    min_obs : int, default 50
        Minimum number of complete-case rows required to attempt a fit.

    Returns
    -------
    dict or None
        ``n_obs``, ``r_squared`` and, for each regressor, ``{var}_coef``,
        ``{var}_se`` and ``{var}_p``. Returns ``None`` when fewer than
        ``min_obs`` complete rows remain or when the fit raises.
    """
    x_vars = list(x_vars)
    cols = [y_var, *x_vars, 'iso3', 'year']
    sub = df[cols].dropna()
    if len(sub) < min_obs:
        return None
    gls = PanelGLS()
    try:
        gls.fit(sub[y_var].values, sub[x_vars].values,
                sub['iso3'].values, sub['year'].values)
    except Exception as err:
        # Best-effort by design: callers treat None as "no estimate". Report
        # the cause on stderr (warnings are globally silenced in this script)
        # so a failed fit is not mistaken for a thin subsample.
        print(f"  run_gls: fit failed for {y_var} (n={len(sub)}): "
              f"{type(err).__name__}: {err}", file=sys.stderr)
        return None
    result = {'n_obs': gls.n_obs, 'r_squared': gls.r_squared}
    for i, var in enumerate(x_vars):
        result[f'{var}_coef'] = gls.beta[i]
        result[f'{var}_se'] = gls.se[i]
        result[f'{var}_p'] = gls.pvalues[i]
    return result


# Short suffixes for each moderator indicator column.
# NOTE(review): not referenced anywhere in this file; presumably shared with
# other phases for naming interaction columns. Verify before changing keys.
MOD_SUFFIX = {
    'income_low': 'low', 'income_high': 'high', 'is_oecd': 'oecd',
    'safe_issuer': 'safe', 'eurozone': 'emu', 'kaopen_saturated': 'ksat',
}


# Subsample label -> moderator column whose Phase 2 interaction term should
# explain a fragility in that subsample. Non-OECD has no direct moderator.
_SUBSAMPLE_MODERATOR = {
    'OECD only': 'is_oecd', 'Non-OECD': None,
    'High income': 'income_high', 'Low income': 'income_low',
    'Safe issuer': 'safe_issuer', 'EMU': 'eurozone',
}

# Scorecard classifications that are passed on to the reconciliation step.
_FRAGILE_CLASSES = ('COLLAPSED', 'SIGN-REVERSED', 'Magnitude-attenuated')

# Candidate dependent variables, in reporting order.
_PHASE6_DVS = ('ca_gdp', 'gross_savings_gdp', 'gross_investment_gdp', 'nfa_gdp')


def _build_subsamples(df):
    """Return an ordered mapping of subsample label -> panel slice.

    A subsample whose indicator column is absent from ``df`` is skipped with
    a console note instead of aborting the run with ``KeyError``.
    """
    specs = [
        ('OECD only', 'is_oecd', 1),
        ('Non-OECD', 'is_oecd', 0),
        ('High income', 'income_high', 1),
        ('Low income', 'income_low', 1),
        ('Safe issuer', 'safe_issuer', 1),
        ('EMU', 'eurozone', 1),
    ]
    subsamples = {'Full panel': df}
    for label, col, value in specs:
        if col not in df.columns:
            print(f"  Skipping subsample '{label}': column '{col}' not in panel")
            continue
        subsamples[label] = df[df[col] == value]
    return subsamples


def _estimate_subsamples(subsamples, dvs):
    """Fit the Z-factor GLS for every DV x subsample.

    Returns one row per pair holding the Z_1 coefficient, its p-value, N and
    R². A failed or too-small fit gets NaN estimates and ``n_obs == 0``.
    """
    rows = []
    for dv in dvs:
        for sname, sdf in subsamples.items():
            m = run_gls(sdf, dv, Z_VARS)
            if m is None:
                rows.append({'dv': dv, 'subsample': sname,
                             'Z1_coef': np.nan, 'Z1_p': np.nan,
                             'n_obs': 0, 'r2': np.nan})
            else:
                rows.append({'dv': dv, 'subsample': sname,
                             'Z1_coef': m['Z_1_coef'], 'Z1_p': m['Z_1_p'],
                             'n_obs': m['n_obs'], 'r2': m['r_squared']})
    return pd.DataFrame(rows)


def _classify_fragility(full_coef, full_p, sub_coef, sub_p):
    """Classify how a subsample Z_1 estimate compares with the full panel.

    Significance is p < 0.10. The magnitude ratio |sub/full| is checked
    against 0.5 (attenuated) and 2.0 (amplified). A missing estimate on
    either side yields 'Insufficient data' rather than a spurious 'null'.
    """
    if any(pd.isna(v) for v in (full_coef, full_p, sub_coef, sub_p)):
        return 'Insufficient data'
    full_sig = full_p < 0.10
    sub_sig = sub_p < 0.10
    if not full_sig:
        return 'Emergent (null→significant)' if sub_sig else 'Null in both'
    # Full-panel estimate is significant from here on.
    if not sub_sig:
        return 'COLLAPSED'
    if np.sign(full_coef) != np.sign(sub_coef):
        return 'SIGN-REVERSED'
    ratio = abs(sub_coef / full_coef) if full_coef != 0 else np.inf
    if ratio < 0.5:
        return 'Magnitude-attenuated'
    if ratio > 2.0:
        return 'Amplified'
    return 'Robust'


def _build_scorecard(sub_df, dvs):
    """Compare every non-full subsample row with its DV's full-panel row."""
    scorecard = []
    for dv in dvs:
        dv_sub = sub_df[sub_df['dv'] == dv]
        full_rows = dv_sub[dv_sub['subsample'] == 'Full panel']
        if full_rows.empty:
            continue
        full = full_rows.iloc[0]
        for _, row in dv_sub.iterrows():
            if row['subsample'] == 'Full panel':
                continue
            scorecard.append({
                'dv': dv, 'subsample': row['subsample'],
                'full_Z1': full['Z1_coef'], 'full_p': full['Z1_p'],
                'sub_Z1': row['Z1_coef'], 'sub_p': row['Z1_p'],
                'n': row['n_obs'],
                'classification': _classify_fragility(
                    full['Z1_coef'], full['Z1_p'],
                    row['Z1_coef'], row['Z1_p']),
            })
    return pd.DataFrame(scorecard)


def _fmt_or_na(val, p):
    """``fmt(val, p)``, or ``'N/A'`` when there is no estimate."""
    return 'N/A' if pd.isna(val) else fmt(val, p)


def _reconcile(sc_df, phase2_df):
    """Check each fragile scorecard row against the Phase 2 interaction tests.

    A row is explained when the Phase 2 row for the same DV and the
    subsample's moderator has a Z₁×moderator p-value (``Z1x_p``) below 0.10.
    Judged rows are printed as they go. Returns the fragile scorecard rows
    plus an ``explained`` label.
    """
    reconciled = []
    for _, row in sc_df.iterrows():
        if row['classification'] not in _FRAGILE_CLASSES:
            continue
        record = row.to_dict()
        mod = _SUBSAMPLE_MODERATOR.get(row['subsample'])
        if mod is None:
            reconciled.append({**record, 'explained': 'No direct moderator'})
            continue

        p2_match = phase2_df[(phase2_df['dv'] == row['dv']) &
                             (phase2_df['moderator'] == mod)]
        if p2_match.empty:
            reconciled.append({**record, 'explained': 'Not tested'})
            continue

        p_int = p2_match.iloc[0]['Z1x_p']
        if p_int < 0.10:
            explained = f"YES (Z₁×{mod} p={p_int:.3f})"
            print(f"  EXPLAINED: {row['dv']} {row['subsample']} "
                  f"[{row['classification']}] — Z₁×{mod} p={p_int:.3f}")
        else:
            explained = f"No (p={p_int:.3f})"
            print(f"  NOT explained: {row['dv']} {row['subsample']} "
                  f"[{row['classification']}] — Z₁×{mod} p={p_int:.3f}")
        reconciled.append({**record, 'explained': explained})
    return pd.DataFrame(reconciled)


def _write_scorecard_table(sc_df, path):
    """Write the full subsample scorecard as a markdown table (UTF-8)."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Phase 6: Fragility Scorecard\n\n")
        f.write("Z₁ coefficient across subsamples, classified by stability.\n\n")
        f.write("| DV | Subsample | Z₁ (full) | Z₁ (sub) | N | Classification |\n")
        f.write("|---|---|---|---|---|---|\n")
        for _, row in sc_df.iterrows():
            full = fmt(row['full_Z1'], row['full_p'])
            sub = _fmt_or_na(row['sub_Z1'], row['sub_p'])
            f.write(f"| {row['dv']} | {row['subsample']} | {full} | "
                    f"{sub} | {int(row['n'])} | {row['classification']} |\n")
        f.write("\n*Classification (significance at p<0.10): "
                "Robust (same sign, similar magnitude), "
                "Magnitude-attenuated (<50% of full), "
                "Amplified (>200% of full), "
                "COLLAPSED (significant→null), SIGN-REVERSED (flips sign), "
                "Emergent (null→significant), Null in both, "
                "Insufficient data (no estimate).*\n")


def _write_reconciliation_table(rec_df, n_explained, path):
    """Write the reconciliation of fragile findings as a markdown table (UTF-8)."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Phase 6: Reconciliation — Which Fragilities Are Explained?\n\n")
        f.write("| DV | Subsample | Fragility | Explained by Framework? |\n")
        f.write("|---|---|---|---|\n")
        for _, row in rec_df.iterrows():
            f.write(f"| {row['dv']} | {row['subsample']} | "
                    f"{row['classification']} | {row['explained']} |\n")
        f.write(f"\n*{n_explained}/{len(rec_df)} fragile findings explained by "
                f"the nonlinear framework.*\n")


def main():
    """Run the Phase 6 fragility scorecard and reconciliation.

    Reads ``unified_panel.csv`` and ``phase2_varying_coefficients.csv`` from
    ``DATA``. Prints the scorecard and reconciliation to stdout and writes
    ``phase6_fragility_scorecard.md`` (plus ``phase6_reconciliation.md`` when
    any fragile rows exist) to ``OUT_TABLES``, creating it if needed.
    """
    print("Phase 6: Reconciliation with Fragility Scorecard")
    print("=" * 70)

    df = pd.read_csv(DATA / "unified_panel.csv")

    # Drop missing DVs once, up front. Previously they were skipped only during
    # estimation, and the scorecard then failed looking up their full-panel row.
    dvs = [dv for dv in _PHASE6_DVS if dv in df.columns]

    # ── Run subsample tests and classify fragility ────────────────────
    subsamples = _build_subsamples(df)
    sub_df = _estimate_subsamples(subsamples, dvs)
    sc_df = _build_scorecard(sub_df, dvs)

    # ── Print summary ─────────────────────────────────────────────────
    print("\nFragility Scorecard:")
    print("-" * 70)
    for _, row in sc_df.iterrows():
        full_str = fmt(row['full_Z1'], row['full_p'])
        sub_str = _fmt_or_na(row['sub_Z1'], row['sub_p'])
        print(f"  {row['dv']:25s} {row['subsample']:15s}: "
              f"{full_str:>12s} → {sub_str:>12s}  [{row['classification']}]")

    # ── Does the interaction model PREDICT these patterns? ────────────
    print("\n\nReconciliation: Does the interaction model predict the collapses?")
    print("-" * 70)

    # Phase 2 varying-coefficient (interaction) results.
    phase2_df = pd.read_csv(DATA / "phase2_varying_coefficients.csv")
    rec_df = _reconcile(sc_df, phase2_df)
    n_explained = (int(rec_df['explained'].str.startswith('YES').sum())
                   if not rec_df.empty else 0)

    # ── Write outputs ─────────────────────────────────────────────────
    print(f"\n{'=' * 70}")
    print("Writing output tables...")
    OUT_TABLES.mkdir(parents=True, exist_ok=True)

    _write_scorecard_table(sc_df, OUT_TABLES / "phase6_fragility_scorecard.md")
    print("  Wrote: phase6_fragility_scorecard.md")

    if not rec_df.empty:
        _write_reconciliation_table(
            rec_df, n_explained, OUT_TABLES / "phase6_reconciliation.md")
        print("  Wrote: phase6_reconciliation.md")

    # Summary counts
    print("\nSummary:")
    if not sc_df.empty:
        for cat, count in sc_df['classification'].value_counts().items():
            print(f"  {cat}: {count}")
    if not rec_df.empty:
        print(f"\n  Explained by framework: {n_explained}/{len(rec_df)} fragile findings")

    print("\nPhase 6 complete.")


if __name__ == '__main__':
    main()
